#!/usr/bin/env python

import math
import yaml
import rospy
import threading
import numpy as np
from interbotix_xs_msgs.msg import *
from interbotix_xs_msgs.srv import *
from urdf_parser_py.urdf import URDF
from sensor_msgs.msg import JointState

### @brief Class that simulates the actual InterbotixRobotXS class
### @details - allows the same higher-level ROS packages or Python API code to be used on a simulated robot with no change to the code
class InterbotixRobotXS(object):
    def __init__(self):
        """Initialize the simulated robot state and bring up its ROS interfaces.

        Loads the URDF (if available) and the motor/mode config files, then
        advertises the services, subscribes to the command topics, creates the
        joint state publisher and starts the timer that publishes joint states.

        Raises:
            RuntimeError: if the Motor Config file could not be loaded.
        """
        self.rd = None                          # Holds the URDF loaded from the ROS parameter server
        self.timer_hz = 20                      # Rate [Hz] at which the ROS Timer (in charge of publishing joint states) should run
        self.js_topic = None                    # Joint States topic
        self.move_thresh = 300                  # Threshold [in milliseconds] above which desired goal positions should be split up into waypoints and simulated
        self.execute_joint_traj = False         # Boolean that is True if a trajectory is currently being executed; False otherwise
        self.joint_states = JointState()        # Current state of all joints
        self.cmd_mutex = threading.Lock()       # Mutex to prevent reading/writing to self.commands from different threads
        self.commands = {}                      # Holds the desired commands for each joint
        self.sleep_map = {}                     # Maps a joint name with its sleep position [rad]
        self.group_map = {}                     # Maps a group name with a JointGroup-type dictionary (refer to the top of the xs_sdk_obj.h file for reference)
        self.motor_map = {}                     # Maps a joint name with a MotorState-type dictionary (refer to the top of the xs_sdk_obj.h file for reference)
        self.gripper_map = {}                   # Maps a gripper name with a Gripper-type dictionary (refer to the top of the xs_sdk_obj.h file for reference)
        self.js_index_map = {}                  # Maps a joint name with its index in self.joint_states.position
        self.gripper_order = []                 # List of grippers as they appear in the 'joint_order' parameter in the Motor Configs YAML file
        self.get_urdf_info()
        # robot_get_motor_configs() returns False when the Motor Config file cannot be opened;
        # without it self.js_topic stays None and nothing below can work, so stop here with a clear error
        if self.robot_get_motor_configs() is False:
            raise RuntimeError("[xs_sdk] Could not load the Motor Config file; unable to start the simulator.")
        rospy.Service('torque_enable', TorqueEnable, self.robot_srv_torque_enable)
        rospy.Service('reboot_motors', Reboot, self.robot_srv_reboot_motors)
        rospy.Service('get_robot_info', RobotInfo, self.robot_srv_get_robot_info)
        rospy.Service('set_operating_modes', OperatingModes, self.robot_srv_set_operating_modes)
        rospy.Service('set_motor_pid_gains', MotorGains, self.robot_srv_set_motor_pid_gains)
        rospy.Service('set_motor_registers', RegisterValues, self.robot_srv_set_motor_registers)
        rospy.Service('get_motor_registers', RegisterValues, self.robot_srv_get_motor_registers)
        rospy.Subscriber('commands/joint_group', JointGroupCommand, self.robot_sub_command_group)
        rospy.Subscriber('commands/joint_single', JointSingleCommand, self.robot_sub_command_single)
        rospy.Subscriber('commands/joint_trajectory', JointTrajectoryCommand, self.robot_sub_command_traj)
        self.pub_joint_states = rospy.Publisher(self.js_topic, JointState, queue_size=1)
        rospy.Timer(rospy.Duration(1.0/self.timer_hz), self.robot_update_joint_states)

    ### @brief Loads the URDF to get joint limit information
    def get_urdf_info(self):
        """Load the URDF from '/<namespace>/robot_description' into self.rd.

        self.rd is left unchanged when that parameter is not on the server.
        """
        namespace = rospy.get_namespace().strip("/")
        param_key = '/' + namespace + '/robot_description'
        if not rospy.has_param(param_key):
            return
        self.rd = URDF.from_parameter_server(key=param_key)

    ### @brief Loads the 'motor_configs' and 'mode_configs' YAML files and populates various class dictionaries
    def robot_get_motor_configs(self):
        """Load the motor and mode config YAML files and populate the class dictionaries.

        The file paths are read from the '~motor_configs' and '~mode_configs'
        private parameters. A missing Mode Config file is tolerated (every group
        then gets the default operating mode settings); a missing Motor Config
        file is not.

        Returns:
            False if the Motor Config file could not be opened, True otherwise.
        """

        # get filepaths to the YAML files from the ROS parameter server
        motor_configs_filepath = rospy.get_param('~motor_configs')
        mode_configs_filepath = rospy.get_param('~mode_configs')

        # try to open and load both files into local variables
        try:
            with open(motor_configs_filepath, "r") as yamlfile:
                motor_configs = yaml.safe_load(yamlfile)
        except IOError:
            rospy.logerr("[xs_sdk] Motor Config File was not found.")
            return False

        try:
            with open(mode_configs_filepath, "r") as yamlfile:
                # an empty YAML document loads as None, so fall back to an empty dictionary
                mode_configs = yaml.safe_load(yamlfile) or {}
        except IOError:
            mode_configs = {}
            rospy.loginfo("[xs_sdk] Mode Config file was not found.")

        joint_order = motor_configs["joint_order"]
        sleep_positions = motor_configs["sleep_positions"]

        self.js_topic = motor_configs["joint_state_publisher"]["topic_name"]

        # optional sections may be absent or left empty (None) in the YAML files
        motor_groups = motor_configs.get("groups") or {}
        grippers = motor_configs.get("grippers") or {}
        motors = motor_configs["motors"]

        mode_groups = mode_configs.get("groups") or {}
        singles = mode_configs.get("singles") or {}

        # populate self.gripper_map
        gripper_keys = ("horn_radius", "arm_length", "left_finger", "right_finger")
        for gpr, items in grippers.items():
            self.gripper_map[gpr] = {key: items[key] for key in gripper_keys}

        # populate self.sleep_map, self.gripper_order, and self.js_index_map
        # also, initialize self.joint_states
        for indx, joint in enumerate(joint_order):
            sleep_position = sleep_positions[indx]
            self.sleep_map[joint] = sleep_position
            if joint in self.gripper_map:
                self.gripper_order.append(joint)
                self.gripper_map[joint]["js_index"] = indx
            self.js_index_map[joint] = indx
            self.joint_states.name.append(joint)
            self.joint_states.position.append(sleep_position)

        # continue to populate self.joint_states with gripper finger info;
        # the fingers are appended after all the motor joints
        for gpr in self.gripper_order:
            lin_dist = self.robot_convert_angular_position_to_linear(gpr, self.sleep_map[gpr])
            for finger, finger_pos in (("left_finger", lin_dist), ("right_finger", -lin_dist)):
                fngr = self.gripper_map[gpr][finger]
                self.js_index_map[fngr] = len(self.js_index_map)
                self.joint_states.name.append(fngr)
                self.joint_states.position.append(finger_pos)

        # populate self.motor_map
        for name in joint_order:
            self.motor_map[name] = {"motor_id" : motors[name]["ID"]}

        # populate self.group_map; the implicit 'all' group is configured first so that
        # the settings of the user-defined groups are applied on top of its defaults
        groups = ["all"] + [grp for grp in motor_groups if grp != "all"]
        motor_groups["all"] = joint_order
        for grp in groups:
            joint_names = motor_groups[grp]
            self.group_map[grp] = {"joint_names" : joint_names,
                                   "joint_num" : len(joint_names),
                                   "joint_ids" : [self.motor_map[name]["motor_id"] for name in joint_names]}
            # the 'all' group always uses the defaults; any other group without an entry in
            # the Mode Config file (or with no Mode Config file at all) falls back to them too
            grp_modes = {} if grp == "all" else (mode_groups.get(grp) or {})
            mode = grp_modes.get("operating_mode", "position")
            profile_type = grp_modes.get("profile_type", "velocity")
            profile_velocity = grp_modes.get("profile_velocity", 0)
            profile_acceleration = grp_modes.get("profile_acceleration", 0)
            self.robot_set_operating_modes("group", grp, mode, profile_type, profile_velocity, profile_acceleration)

        # continue to populate self.motor_map
        for sgl, info in singles.items():
            if sgl not in self.motor_map:
                rospy.logwarn("[xs_sdk] The '%s' joint in the Mode Config file is not in the Motor Config file. Skipping..." % sgl)
                continue
            info = info or {}
            mode = info.get("operating_mode", "position")
            profile_type = info.get("profile_type", "velocity")
            profile_velocity = info.get("profile_velocity", 0)
            profile_acceleration = info.get("profile_acceleration", 0)
            self.robot_set_joint_operating_mode(sgl, mode, profile_type, profile_velocity, profile_acceleration)

        return True

    ### @brief Set the operating mode for a specific group of motors or a single motor
    ### @param cmd_type - set to 'group' if changing the operating mode for a group of motors or 'single' if changing the operating mode for a single motor
    ### @param name - desired motor group name if cmd_type is set to 'group' or the desired motor name if cmd_type is set to 'single'
    ### @param mode - desired operating mode (either 'position', 'linear_position', 'ext_position', 'velocity', 'pwm', 'current', or 'current_based_position')
    ### @param profile_type - set to 'velocity' for a Velocity-based Profile or 'time' for a Time-based Profile (only 'time' is supported in the simulator)
    ### @param profile_velocity - passthrough to the Profile_Velocity register on the motor
    ### @param profile_acceleration - passthrough to the Profile_Acceleration register on the motor
    def robot_set_operating_modes(self, cmd_type, name, mode, profile_type, profile_velocity, profile_acceleration):
        """Apply an operating mode to every joint of a group or to a single joint.

        Nothing is changed (only a message is logged) when 'cmd_type' is neither
        'group' nor 'single', or when 'name' is not a known group/joint.
        """
        if cmd_type not in ("group", "single"):
            rospy.logerr("[xs_sdk] Invalid command for argument 'cmd_type' while setting operating mode.")
            return

        for_group = cmd_type == "group"
        known_names = self.group_map if for_group else self.motor_map
        if name not in known_names:
            rospy.loginfo("[xs_sdk] The '%s' joint/group does not exist. Was it added to the motor config file?" % name)
            return

        if for_group:
            group = self.group_map[name]
            for joint_name in group["joint_names"]:
                self.robot_set_joint_operating_mode(joint_name, mode, profile_type, profile_velocity, profile_acceleration)
            # the group itself only remembers the mode and the profile type
            group["mode"] = mode
            group["profile_type"] = profile_type
            rospy.loginfo("[xs_sdk] The operating mode for the '%s' group was changed to %s." % (name, mode))
        else:
            self.robot_set_joint_operating_mode(name, mode, profile_type, profile_velocity, profile_acceleration)
            rospy.loginfo("[xs_sdk] The operating mode for the '%s' joint was changed to %s." % (name, mode))

    ### @brief Helper function used to set the operating mode for a single motor
    ### @param name - desired motor name
    ### @param mode - desired operating mode (either 'position', 'linear_position', 'ext_position', 'velocity', 'pwm', 'current', or 'current_based_position')
    ### @param profile_type - set to 'velocity' for a Velocity-based Profile or 'time' for a Time-based Profile (only 'time' is supported in the simulator)
    ### @param profile_velocity - passthrough to the Profile_Velocity register on the motor
    ### @param profile_acceleration - passthrough to the Profile_Acceleration register on the motor
    def robot_set_joint_operating_mode(self, name, mode, profile_type, profile_velocity, profile_acceleration):
        """Record the operating mode and profile settings in the joint's self.motor_map entry.

        Raises KeyError if 'name' is not a key of self.motor_map.
        """
        motor = self.motor_map[name]
        motor.update(
            mode=mode,
            profile_type=profile_type,
            profile_velocity=profile_velocity,
            profile_acceleration=profile_acceleration,
        )

    ### @brief Command a desired group of motors with the specified commands
    ### @param name - desired motor group name
    ### @param commands - vector of commands (order matches the order specified in the 'groups' section in the motor config file)
    ### @details - commands are processed differently based on the operating mode specified for the motor group
    def robot_write_commands(self, name, commands):
        joints = self.group_map[name]["joint_names"]
        for x in range(len(joints)):
            self.robot_write_joint_command(joints[x], commands[x])

    ### @brief Command a desired motor with the specified command
    ### @param name - desired motor name
    ### @param command - motor command
    ### @details - the command is processed differently based on the operating mode specified for the motor
    def robot_write_joint_command(self, name, command):
        motor = self.motor_map[name]
        mode = motor["mode"]
        prof_vel = motor["profile_velocity"]
        with self.cmd_mutex:
            # create a trajectory of 'waypoints' to execute if motion should take longer than 'self.move_thresh' milliseconds
            if "position" in mode and prof_vel > self.move_thresh:
                num_itr = int(round(prof_vel / 1000.0 * self.timer_hz + 1))
                self.commands[name] = list(np.linspace(command, self.joint_states.position[self.js_index_map[name]], num_itr))
            else:
                self.commands[name] = command

    ### @brief Converts a desired linear distance between two gripper fingers into an angular position
    ### @param name - name of the gripper servo to command
    ### @param linear_position - desired distance [m] between the two gripper fingers
    ### @return result - angular position [rad] that achieves the desired linear distance
    def robot_convert_linear_position_to_radian(self, name, linear_position):
        half_dist = linear_position / 2.0
        arm_length = self.gripper_map[name]["arm_length"]
        horn_radius = self.gripper_map[name]["horn_radius"]
        result = math.pi/2.0 - math.acos((horn_radius**2 + half_dist**2 - arm_length**2) / (2 * horn_radius * half_dist))
        return result

    ### @brief Converts a specified angular position into the linear distance from one gripper finger to the center of the gripper servo horn
    ### @param name - name of the gripper servo to command
    ### @param angular_position - desired gripper angular position [rad]
    ### @return - linear position [m] from a gripper finger to the center of the gripper servo horn
    def robot_convert_angular_position_to_linear(self, name, angular_position):
        arm_length = self.gripper_map[name]["arm_length"]
        horn_radius = self.gripper_map[name]["horn_radius"]
        a1 = horn_radius * math.sin(angular_position)
        c = math.sqrt(horn_radius**2 - a1**2)
        a2 = math.sqrt(arm_length**2 - c**2)
        return a1 + a2

    ### @brief ROS Subscriber callback function to command a group of joints
    ### @param msg - JointGroupCommand message dictating the joint group to command along with the actual commands
    ### @details - refer to the message definition for details
    def robot_sub_command_group(self, msg):
        self.robot_write_commands(msg.name, msg.cmd)

    ### @brief ROS Subscriber callback function to command a single joint
    ### @param msg - JointSingleCommand message dictating the joint to command along with the actual command
    ### @details - refer to the message definition for details
    def robot_sub_command_single(self, msg):
        self.robot_write_joint_command(msg.name, msg.cmd)

    ### @brief ROS Subscriber callback function to command a joint trajectory
    ### @param msg - JointTrajectoryCommand message dictating the joint(s) to command along with the desired trajectory
    ### @details - refer to the message definition for details
    def robot_sub_command_traj(self, msg):
        """Subscriber callback: play back a joint trajectory point by point.

        The message is rejected if another trajectory is still executing or if
        it has fewer than two points. The first point is only compared against
        the current joint states (a mismatch is logged, not treated as an
        error); every following point is written as a command, sleeping between
        points until that point's 'time_from_start' has elapsed. The callback
        blocks until the whole trajectory has been sent.
        """

        # check to make sure there is no trajectory currently running and that it is valid
        if self.execute_joint_traj:
            rospy.loginfo("[xs_sdk] Trajectory rejected since joints are still moving.")
            return
        if len(msg.traj.points) < 2:
            rospy.loginfo("[xs_sdk] Trajectory has fewer than 2 points. Aborting...")
            return

        # get the mode and joint_names
        mode = None
        joint_names = []
        if (msg.cmd_type == "group"):
            mode = self.group_map[msg.name]["mode"]
            joint_names = self.group_map[msg.name]["joint_names"]
        elif (msg.cmd_type == "single"):
            mode = self.motor_map[msg.name]["mode"]
            joint_names.append(msg.name)

        # check to see if the initial positions match the current joint states
        if len(msg.traj.points[0].positions) == len(joint_names):
            for x in range(len(joint_names)):
                expected_state = msg.traj.points[0].positions[x]
                actual_state = self.joint_states.position[self.js_index_map[joint_names[x]]]
                if not (abs(expected_state - actual_state) < 0.01):
                    rospy.loginfo("[xs_sdk] The %s joint is not at the correct initial state." % joint_names[x])
                    rospy.loginfo("[xs_sdk] Expected state: %.2f, Actual State: %.2f." % (expected_state, actual_state))

        # execute the trajectory; the flag checked at the top of this callback must be
        # the one set here, and it is cleared in 'finally' so that an exception while
        # writing commands cannot leave every later trajectory rejected
        self.execute_joint_traj = True
        try:
            points = msg.traj.points
            time_start = rospy.Time.now()
            for x in range(1, len(points)):
                if msg.cmd_type == "group":
                    if "position" in mode:
                        self.robot_write_commands(msg.name, points[x].positions)
                    elif mode == "velocity":
                        self.robot_write_commands(msg.name, points[x].velocities)
                    elif mode == "pwm" or mode == "current":
                        self.robot_write_commands(msg.name, points[x].effort)
                elif msg.cmd_type == "single":
                    if "position" in mode:
                        self.robot_write_joint_command(msg.name, points[x].positions[0])
                    elif mode == "velocity":
                        self.robot_write_joint_command(msg.name, points[x].velocities[0])
                    elif mode == "pwm" or mode == "current":
                        self.robot_write_joint_command(msg.name, points[x].effort[0])
                # no need to wait after the final point has been commanded
                if x < len(points) - 1:
                    period = (points[x].time_from_start - (rospy.Time.now() - time_start))
                    rospy.sleep(period)
        finally:
            self.execute_joint_traj = False

    ### @brief ROS Service to 'simulate' torquing the joints on the robot on/off (doesn't actually do anything)
    ### @param req - TorqueEnable service message request
    ### @return <obj> - TorqueEnable service message response [unused]
    ### @details - refer to the service definition for details
    def robot_srv_torque_enable(self, req):
        """Service callback: log the requested torque change and return an empty response.

        No simulated state is modified.
        """
        # keyed by (request targets a group, torque is being enabled)
        templates = {
            (True, True): "[xs_sdk] The '%s' group was torqued on.",
            (True, False): "[xs_sdk] The '%s' group was torqued off.",
            (False, True): "[xs_sdk] The '%s' joint was torqued on.",
            (False, False): "[xs_sdk] The '%s' joint was torqued off.",
        }
        template = templates[(req.cmd_type == "group", bool(req.enable))]
        rospy.loginfo(template % req.name)
        return TorqueEnableResponse()

    ### @brief ROS Service to 'simulate' rebooting the motors on the robot (doesn't actually do anything)
    ### @param req - Reboot service message request
    ### @return <obj> - Reboot service message response [unused]
    ### @details - refer to the service definition for details
    def robot_srv_reboot_motors(self, req):
        """Service callback: log the requested reboot and return an empty response.

        No simulated state is modified.
        """
        template = ("[xs_sdk] The '%s' group was rebooted." if req.cmd_type == "group"
                    else "[xs_sdk] The '%s' joint was rebooted.")
        rospy.loginfo(template % req.name)
        return RebootResponse()

    ### @brief ROS Service that allows the user to get information about the robot
    ### @param req - RobotInfo service message request
    ### @return <obj> - RobotInfo service message response
    ### @details - refer to the service definition for details
    def robot_srv_get_robot_info(self, req):
        """Service callback: report the joints, modes and limits of a group or a single joint.

        For gripper joints the reported joint name, joint state index and URDF
        limits are those of the gripper's left finger, and the reported sleep
        position is the linear position that corresponds to a servo angle of 0.
        Joint limits are only filled in when a URDF was loaded.

        Raises KeyError if req.name is not a known group/joint for req.cmd_type.
        """
        res = RobotInfoResponse()
        joint_names = []
        if (req.cmd_type == "group"):
            group = self.group_map[req.name]
            joint_names = group["joint_names"]
            res.profile_type = group["profile_type"]
            res.mode = group["mode"]
        elif (req.cmd_type == "single"):
            motor = self.motor_map[req.name]
            joint_names.append(req.name)
            res.profile_type = motor["profile_type"]
            res.mode = motor["mode"]

        res.num_joints = len(joint_names)

        for name in joint_names:
            res.joint_ids.append(self.motor_map[name]["motor_id"])
            if name in self.gripper_map:
                res.joint_sleep_positions.append(self.robot_convert_angular_position_to_linear(name, 0))
                # everything below is reported for the left finger instead of the gripper servo
                js_name = self.gripper_map[name]["left_finger"]
            else:
                res.joint_sleep_positions.append(self.sleep_map[name])
                js_name = name
            res.joint_names.append(js_name)
            res.joint_state_indices.append(self.js_index_map[js_name])
            if self.rd is not None:
                # NOTE(review): if the URDF has no joint called js_name this is None and the
                # attribute access below raises AttributeError - confirm that is acceptable
                joint_object = next((joint for joint in self.rd.joints if joint.name == js_name), None)
                res.joint_lower_limits.append(joint_object.limit.lower)
                res.joint_upper_limits.append(joint_object.limit.upper)
                res.joint_velocity_limits.append(joint_object.limit.velocity)
        # only the exact name 'all' expands to the list of every group; a substring test
        # would also match names that merely contain 'all' (e.g. 'small_gripper')
        if req.name != 'all':
            res.name.append(req.name)
        else:
            for key in self.group_map:
                res.name.append(key)
        return res

    ### @brief ROS Service that allows the user to change operating modes
    ### @param req - OperatingModes service message request
    ### @return <obj> - OperatingModes service message response [unused]
    ### @details - refer to the service definition for details
    def robot_srv_set_operating_modes(self, req):
        """Service callback: apply the requested operating mode and return an empty response."""
        self.robot_set_operating_modes(
            cmd_type=req.cmd_type,
            name=req.name,
            mode=req.mode,
            profile_type=req.profile_type,
            profile_velocity=req.profile_velocity,
            profile_acceleration=req.profile_acceleration,
        )
        return OperatingModesResponse()

    ### @brief ROS Service that allows the user to set the motor firmware PID gains (doesn't actually do anything)
    ### @param req - MotorGains service message request
    ### @return <obj> - MotorGains service message response [unused]
    ### @details - refer to the service definition for details
    def robot_srv_set_motor_pid_gains(self, req):
        """Service callback: ignore the request and return an empty MotorGains response."""
        response = MotorGainsResponse()
        return response

    ### @brief ROS Service that allows the user to change a specific register to a specific value for multiple motors
    ### @param req - RegisterValues service message request
    ### @return <obj> - RegisterValues service message response [unused]
    ### @details - refer to the service definition for details - only works with the 'Profile_Velocity' and
    ###            'Profile_Acceleration' registers; otherwise, it doesn't do anything
    def robot_srv_set_motor_registers(self, req):
        """Service callback: store a new Profile_Velocity/Profile_Acceleration value.

        For a group, the value is stored on the group entry and on each of its
        joints; for a single joint, on that joint only. Any other register name
        is silently ignored. An empty response is always returned.
        """
        register_keys = {"Profile_Velocity": "profile_velocity",
                         "Profile_Acceleration": "profile_acceleration"}
        reg = register_keys.get(req.reg)
        if reg is not None:
            if req.cmd_type == "group":
                group = self.group_map[req.name]
                group[reg] = req.value
                for joint in group["joint_names"]:
                    self.motor_map[joint][reg] = req.value
            elif req.cmd_type == "single":
                self.motor_map[req.name][reg] = req.value
        return RegisterValuesResponse()

    ### @brief ROS Service that allows the user to read a specific register on multiple motors
    ### @param req - RegisterValues service message request
    ### @return <obj> - RegisterValues service message response
    ### @details - refer to the service definition for details - only works with the 'Profile_Velocity' and
    ###            'Profile_Acceleration' registers; otherwise, it doesn't do anything
    def robot_srv_get_motor_registers(self, req):
        """Service callback: return the stored Profile_Velocity/Profile_Acceleration values.

        One value per joint of the group (in group order), or a single value for
        a single joint. The response's values are left empty for any other register.
        """
        res = RegisterValuesResponse()
        register_keys = {"Profile_Velocity": "profile_velocity",
                         "Profile_Acceleration": "profile_acceleration"}
        reg = register_keys.get(req.reg)
        if reg is None:
            return res

        if req.cmd_type == "group":
            targets = self.group_map[req.name]["joint_names"]
        elif req.cmd_type == "single":
            targets = [req.name]
        else:
            targets = []
        for joint in targets:
            res.values.append(self.motor_map[joint][reg])
        return res

    ### @brief ROS Timer that updates the joint states based on the current 'self.commands'
    ### @param event - TimerEvent message [unused]
    ### @details - if a joint is in any form of 'position' mode and its profile_velocity is greater than
    ###            'self.move_thresh' milliseconds, then 'self.commands[joint]' is a list of waypoints (in reverse);
    ###            that way, the latest position to execute can just be 'popped' off the end of the list - which is more efficient;
    ###            when the last element is taken out from the list, the dictionary key is thrown out. Otherwise, 'self.commands[joint]'
    ###            is a scalar. If the scalar is a position, the dictionary key is thrown out afterwards. Otherwise, the dictionary key is
    ###            not thrown out unless its value is 0 (which is what you want for 'pwm', 'current', or 'velocity' modes)
    def robot_update_joint_states(self, event):
        """Timer callback: advance every commanded joint by one step and publish the joint states.

        Position-type commands are consumed: a waypoint list gives up its last
        element each tick (and is dropped once empty), a scalar is applied once
        and dropped. Commands in any other mode are integrated into the joint
        position on every tick and stay queued until a command of 0 is received.
        Gripper finger positions are kept in sync with their servo.

        'event' is the rospy TimerEvent and is not used.
        """
        with self.cmd_mutex:
            positions = self.joint_states.position
            # iterate over a snapshot of the keys since entries are deleted inside the loop
            for joint in list(self.commands):
                mode = self.motor_map[joint]["mode"]
                js_index = self.js_index_map[joint]

                if "position" in mode:
                    pending = self.commands[joint]
                    if isinstance(pending, list):
                        # waypoints are stored goal-first, so the next one to execute is at the end
                        value = pending.pop()
                        if not pending:
                            del self.commands[joint]
                    else:
                        value = pending
                        del self.commands[joint]

                    if joint not in self.gripper_map:
                        positions[js_index] = value
                    elif mode == "linear_position":
                        # 'value' is the finger-to-finger distance; each finger sits at half of it
                        positions[js_index] = self.robot_convert_linear_position_to_radian(joint, value)
                        self._robot_set_finger_positions(joint, value / 2.0)
                    else:
                        lin_pos = self.robot_convert_angular_position_to_linear(joint, value)
                        positions[js_index] = value
                        self._robot_set_finger_positions(joint, lin_pos)

                else:
                    value = self.commands[joint]
                    if value == 0:
                        del self.commands[joint]
                        continue
                    if mode == "velocity":
                        # float() keeps this a true division under Python 2 as well
                        new_val = value / float(self.timer_hz)
                    else:
                        # treat 'pwm' and 'current' equivalently
                        new_val = value / 2000.0
                    positions[js_index] += new_val
                    if joint in self.gripper_map:
                        lin_pos = self.robot_convert_angular_position_to_linear(joint, positions[js_index])
                        self._robot_set_finger_positions(joint, lin_pos)

            self.joint_states.header.stamp = rospy.Time.now()
            self.pub_joint_states.publish(self.joint_states)

    def _robot_set_finger_positions(self, gripper_name, finger_offset):
        """Place the gripper's left finger at +finger_offset and its right finger at -finger_offset."""
        gpr = self.gripper_map[gripper_name]
        positions = self.joint_states.position
        positions[self.js_index_map[gpr["left_finger"]]] = finger_offset
        positions[self.js_index_map[gpr["right_finger"]]] = -finger_offset

def main():
    """Start the 'xs_sdk_sim' node, create the simulated robot and spin until shutdown."""
    rospy.init_node('xs_sdk_sim')
    _simulated_robot = InterbotixRobotXS()
    rospy.spin()


if __name__ == '__main__':
    main()
